{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "test.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "toc_visible": true,
      "private_outputs": true,
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "metadata": {
        "colab_type": "text",
        "id": "Tce3stUlHN0L"
      },
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2018 The TensorFlow Authors.\n",
        "\n"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "tuOe1ymfHZPu",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "D6zc4Q6bxmHM",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# TensorFlow Lite Gesture Classification Example Conversion Script\n",
        "\n",
        "\n",
        "This guide shows how to convert a model trained with TensorFlow.js to a TensorFlow Lite FlatBuffer.\n",
        "\n",
        "Run all steps in order. At the end, the `model.tflite` file will be downloaded.\n"
      ]
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "MfBg1C5NB3X0"
      },
      "cell_type": "markdown",
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/mobile/examples/gesture_classification/ml/tensorflowjs_to_tflite_colab_notebook.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/mobile/examples/gesture_classification/ml/tensorflowjs_to_tflite_colab_notebook.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "metadata": {
        "id": "yVMF3Q_HnJ09",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "**Install Dependencies**"
      ]
    },
    {
      "metadata": {
        "id": "FbMsNJ4PAq2j",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!pip3 install tensorflow==1.11.0 keras==2.2.2 tensorflowjs==0.6.4 --force-reinstall"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "WZGFStffPj_Z",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "import traceback\n",
        "import logging\n",
        "import tensorflow as tf\n",
        "import keras.backend as K\n",
        "import os\n",
        "\n",
        "from google.colab import files\n",
        "\n",
        "from keras import Model, Input\n",
        "from keras.applications import MobileNet\n",
        "from keras.engine.saving import load_model\n",
        "\n",
        "from tensorflowjs.converters import load_keras_model\n",
        "\n",
        "logging.basicConfig(level=logging.INFO)\n",
        "logger = logging.getLogger(__name__)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Lh7zgNXVx8ML",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "***Clean up any existing models if necessary***"
      ]
    },
    {
      "metadata": {
        "id": "JrMA8frMx7aa",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!rm -rf *.h5 *.tflite *.json *.bin"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ct36DONNnNZJ",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "**Upload your TensorFlow.js Artifacts Here**\n",
        "\n",
        "i.e., the weights manifest **model.json** and the binary weights file **model-weights.bin**"
      ]
    },
    {
      "metadata": {
        "id": "s-_80hGtMTFb",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Bind the result instead of leaving the call as the cell's last expression:\n",
        "# files.upload() returns a {filename: contents} dict, and echoing it would\n",
        "# dump the raw bytes of the uploaded weights into the cell output.\n",
        "uploaded_files = files.upload()\n",
        "logger.info(\"Uploaded: %s\", \", \".join(sorted(uploaded_files)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "_ctAZ--FnStM",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "**Export Configuration**"
      ]
    },
    {
      "metadata": {
        "id": "hrzzoZP5oK7S",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "#@title Export Configuration\n",
        "\n",
        "# TensorFlow.js arguments\n",
        "\n",
        "config_json = \"model.json\" #@param {type:\"string\"}\n",
        "weights_path_prefix = None #@param {type:\"raw\"}\n",
        "model_tflite = \"model.tflite\" #@param {type:\"string\"}\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "RA0iINpNiK_p",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "**Model Converter**\n",
        "\n",
        "The following class converts a TensorFlow.js model to a TFLite FlatBuffer"
      ]
    },
    {
      "metadata": {
        "id": "8QMjVgxVggQJ",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "class ModelConverter:\n",
        "    \"\"\"Converts a TensorFlow.js Keras-style model into a TFLite FlatBuffer.\n",
        "\n",
        "    The TensorFlow.js artifacts are loaded as a Keras model, stacked on top of\n",
        "    a Keras MobileNet base, saved to a temporary HDF5 file, and that file is\n",
        "    then handed to the TFLite converter.\n",
        "\n",
        "    Args:\n",
        "      config_json_path: Path of the TensorFlow.js model JSON file; passed to\n",
        "        `load_keras_model` as its first argument.\n",
        "      weights_path_prefix: Passed through to `load_keras_model`; `None`\n",
        "        leaves the lookup of the weight binaries to that function.\n",
        "      tflite_model_file: Path the TFLite FlatBuffer is written to.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self,\n",
        "                 config_json_path,\n",
        "                 weights_path_prefix,\n",
        "                 tflite_model_file\n",
        "                 ):\n",
        "        self.config_json_path = config_json_path\n",
        "        self.weights_path_prefix = weights_path_prefix\n",
        "        self.tflite_model_file = tflite_model_file\n",
        "        # Intermediate Keras file; removed again at the end of convert().\n",
        "        self.keras_model_file = 'merged.h5'\n",
        "\n",
        "        # MobileNet Options\n",
        "        # NOTE(review): these presumably have to match the base network the\n",
        "        # TensorFlow.js model was trained against -- confirm before changing.\n",
        "        self.input_node_name = 'the_input'\n",
        "        self.image_size = 224\n",
        "        self.alpha = 0.25\n",
        "        self.depth_multiplier = 1\n",
        "        self._input_shape = (1, self.image_size, self.image_size, 3)\n",
        "        # Name of the base-model layer whose output feeds the uploaded model.\n",
        "        self.depthwise_conv_layer = 'conv_pw_13_relu'\n",
        "\n",
        "    def convert(self):\n",
        "        \"\"\"Runs the full conversion and writes `self.tflite_model_file`.\n",
        "\n",
        "        The intermediate HDF5 file is removed whether or not the conversion\n",
        "        succeeds, so a failed run does not leave `merged.h5` behind.\n",
        "        \"\"\"\n",
        "        try:\n",
        "            self.save_keras_model()\n",
        "            self._deserialize_tflite_from_keras()\n",
        "            logger.info('The TFLite model has been generated')\n",
        "        finally:\n",
        "            self._purge()\n",
        "\n",
        "    def save_keras_model(self):\n",
        "        \"\"\"Loads the TensorFlow.js model, merges it with MobileNet and saves it as HDF5.\"\"\"\n",
        "        top_model = load_keras_model(self.config_json_path, self.weights_path_prefix,\n",
        "                                     weights_data_buffers=None,\n",
        "                                     load_weights=True,\n",
        "                                     use_unique_name_scope=True)\n",
        "\n",
        "        base_model = self.get_base_model()\n",
        "        merged_model = self.merge(base_model, top_model)\n",
        "        merged_model.save(self.keras_model_file)\n",
        "\n",
        "        logger.info(\"The merged Keras HDF5 model has been saved as {}\".format(self.keras_model_file))\n",
        "\n",
        "    def merge(self, base_model, top_model):\n",
        "        \"\"\"Stacks `top_model` on the `self.depthwise_conv_layer` output of `base_model`.\n",
        "\n",
        "        :param base_model: Keras model providing the input and the tapped layer.\n",
        "        :param top_model: Keras model applied to the tapped layer's output.\n",
        "        :return: The merged Keras model.\n",
        "        \"\"\"\n",
        "        logger.info(\"Initializing model...\")\n",
        "\n",
        "        layer = base_model.get_layer(self.depthwise_conv_layer)\n",
        "        model = Model(inputs=base_model.input, outputs=top_model(layer.output))\n",
        "        logger.info(\"Model created.\")\n",
        "\n",
        "        return model\n",
        "\n",
        "    def get_base_model(self):\n",
        "        \"\"\"Builds the headless MobileNet configured by the options set in `__init__`.\n",
        "\n",
        "        :return: The base MobileNet Keras model (`include_top=False`).\n",
        "        \"\"\"\n",
        "        # Drop the leading batch dimension; Keras expects (height, width, channels).\n",
        "        image_shape = self._input_shape[1:]\n",
        "        input_tensor = Input(shape=image_shape, name=self.input_node_name)\n",
        "        base_model = MobileNet(input_shape=image_shape,\n",
        "                               alpha=self.alpha,\n",
        "                               depth_multiplier=self.depth_multiplier,\n",
        "                               input_tensor=input_tensor,\n",
        "                               include_top=False)\n",
        "        return base_model\n",
        "\n",
        "    @staticmethod\n",
        "    def _get_tflite_converter_class():\n",
        "        \"\"\"Returns the TFLite converter class offered by the installed TensorFlow.\n",
        "\n",
        "        The class moved between releases, and the TensorFlow version pinned by\n",
        "        this notebook (1.11.0) predates `tf.compat.v1`, so the known locations\n",
        "        are probed from newest to oldest.\n",
        "\n",
        "        :raises AttributeError: If none of the known locations exists.\n",
        "        \"\"\"\n",
        "        candidate_paths = (\n",
        "            'compat.v1.lite.TFLiteConverter',  # TensorFlow >= 1.13\n",
        "            'contrib.lite.TFLiteConverter',    # TensorFlow 1.12\n",
        "            'contrib.lite.TocoConverter',      # TensorFlow 1.9 - 1.11\n",
        "        )\n",
        "        for dotted_path in candidate_paths:\n",
        "            target = tf\n",
        "            try:\n",
        "                for attribute in dotted_path.split('.'):\n",
        "                    target = getattr(target, attribute)\n",
        "            except AttributeError:\n",
        "                continue\n",
        "            return target\n",
        "        raise AttributeError('No TFLite converter class found in TensorFlow {}'.format(tf.__version__))\n",
        "\n",
        "    def _deserialize_tflite_from_keras(self):\n",
        "        \"\"\"Converts the saved HDF5 model and writes the FlatBuffer to `self.tflite_model_file`.\"\"\"\n",
        "        converter_class = self._get_tflite_converter_class()\n",
        "        converter = converter_class.from_keras_model_file(self.keras_model_file)\n",
        "        tflite_model = converter.convert()\n",
        "\n",
        "        with open(self.tflite_model_file, \"wb\") as tflite_file:\n",
        "            tflite_file.write(tflite_model)\n",
        "\n",
        "    def _purge(self):\n",
        "        \"\"\"Deletes the intermediate HDF5 file if it was created.\"\"\"\n",
        "        # The file may not exist when saving failed before anything was written.\n",
        "        if os.path.exists(self.keras_model_file):\n",
        "            logger.info('Cleaning up Keras model')\n",
        "            os.remove(self.keras_model_file)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "qUeoHM-Jg7uv",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "try:\n",
        "    # Start from a fresh Keras session so re-running this cell does not build\n",
        "    # on layers left over from a previous attempt.\n",
        "    K.clear_session()\n",
        "    converter = ModelConverter(config_json,\n",
        "                               weights_path_prefix,\n",
        "                               model_tflite)\n",
        "\n",
        "    converter.convert()\n",
        "\n",
        "except Exception:\n",
        "    # Report every failure (not only ValueError) and re-raise: the notebook\n",
        "    # shows the traceback, and \"Run all\" stops here instead of moving on to\n",
        "    # download a file that was never written.\n",
        "    print(\"Error occurred while converting\")\n",
        "    raise"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "G7noTBgTg8Fz",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "files.download(model_tflite)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}